{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# **RAG from Scratch**\n",
        "\n",
        "Authored by [Kalyan KS](https://www.linkedin.com/in/kalyanksnlp/). To stay updated with LLM, RAG and Agent updates, you can follow me on [Twitter](https://x.com/kalyan_kpl).\n",
        "\n",
        "- Step-1 : Extract text\n",
        "- Step-2 : Chunk the extracted text\n",
        "- Step-3 : Create a vector store with the chunks\n",
        "- Step-4 : Create a retriever which returns the relevant chunks\n",
        "- Step-5 : Build context from the relevant chunk texts\n",
        "- Step-6 : Build the RAG pipeline\n",
        "- Step-7 : Run the RAG pipeline to get the answer."
      ],
      "metadata": {
        "id": "hBldBXEn-s2G"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Install libraries**"
      ],
      "metadata": {
        "id": "ZExw2kIQ-vIu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VeQwqlMC9mOE",
        "outputId": "38534f94-8d47-4f44-8f90-6e60077cf0b0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/67.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.3/67.3 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m232.6/232.6 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m611.1/611.1 kB\u001b[0m \u001b[31m21.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m52.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.8/6.8 MB\u001b[0m \u001b[31m54.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m278.6/278.6 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m43.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.6/101.6 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.9/55.9 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.4/177.4 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.0/65.0 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m118.7/118.7 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.0/73.0 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m32.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.3/62.3 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m459.8/459.8 kB\u001b[0m \u001b[31m20.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m319.7/319.7 kB\u001b[0m \u001b[31m16.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.5/71.5 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m39.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m452.6/452.6 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for pypika (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install -qU PyPDF2 chromadb litellm"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Set up the LLM API Key**"
      ],
      "metadata": {
        "id": "JGPGrBwB_UfQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import userdata\n",
        "import os\n",
        "os.environ['OPENAI_API_KEY'] = userdata.get(\"OPENAI_API_KEY\")"
      ],
      "metadata": {
        "id": "59BAtNsw_WX0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Extract Text**"
      ],
      "metadata": {
        "id": "BItWSfV5_YEP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from PyPDF2 import PdfReader"
      ],
      "metadata": {
        "id": "wl1GlnRz_kv3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import List\n",
        "from PyPDF2 import PdfReader\n",
        "\n",
        "def text_extract(pdf_path: str) -> str:\n",
        "    \"\"\"\n",
        "    Extracts text from all pages of a given PDF file.\n",
        "\n",
        "    Args:\n",
        "        pdf_path (str): Path to the PDF file.\n",
        "\n",
        "    Returns:\n",
        "        str: Extracted text from the PDF, concatenated with newline separators.\n",
        "    \"\"\"\n",
        "\n",
        "    # An empty list to store extracted text from PDF pages\n",
        "    pdf_pages = []\n",
        "\n",
        "    # Open the PDF file in binary read mode\n",
        "    with open(pdf_path, 'rb') as file:\n",
        "\n",
        "        # Create a PdfReader object to read the PDF\n",
        "        pdf_reader = PdfReader(file)\n",
        "\n",
        "        # Iterate through all pages in the PDF\n",
        "        for page in pdf_reader.pages:\n",
        "\n",
        "            # Extract text from the current page\n",
        "            text = page.extract_text()\n",
        "\n",
        "            # Append the extracted text to the list\n",
        "            pdf_pages.append(text)\n",
        "\n",
        "    # Join all extracted text using newline separator\n",
        "    pdf_text = \"\\n\".join(pdf_pages)\n",
        "\n",
        "    # Return the extracted text as a single string\n",
        "    return pdf_text\n"
      ],
      "metadata": {
        "id": "j9ud4GFX_fMz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the PDF file\n",
        "import requests\n",
        "\n",
        "pdf_url = 'https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf'\n",
        "response = requests.get(pdf_url)\n",
        "\n",
        "pdf_path = 'attention_is_all_you_need.pdf'\n",
        "with open(pdf_path, 'wb') as file:\n",
        "    file.write(response.content)"
      ],
      "metadata": {
        "id": "ttlt3RYWAh8n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pdf_text = text_extract(pdf_path)"
      ],
      "metadata": {
        "id": "_NK7uS8aAwxC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(pdf_text[:300])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XIuE4C3WBCw7",
        "outputId": "3168bc37-5076-4735-cf7e-4e257f465339"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Attention Is All You Need\n",
            "Ashish Vaswani\u0003\n",
            "Google Brain\n",
            "avaswani@google.comNoam Shazeer\u0003\n",
            "Google Brain\n",
            "noam@google.comNiki Parmar\u0003\n",
            "Google Research\n",
            "nikip@google.comJakob Uszkoreit\u0003\n",
            "Google Research\n",
            "usz@google.com\n",
            "Llion Jones\u0003\n",
            "Google Research\n",
            "llion@google.comAidan N. Gomez\u0003y\n",
            "University of Toronto\n",
            "aidan@c\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Chunk Text**\n",
        "\n"
      ],
      "metadata": {
        "id": "6gFK0HfYBb5i"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import List\n",
        "import re\n",
        "from collections import deque\n",
        "\n",
        "\n",
        "def text_chunk(text: str, max_length: int = 1000) -> List[str]:\n",
        "    \"\"\"\n",
        "    Splits a given text into chunks while ensuring that sentences remain intact.\n",
        "\n",
        "    The function maintains sentence boundaries by splitting based on punctuation\n",
        "    (. ! ?) and attempts to fit as many sentences as possible within `max_length`\n",
        "    per chunk.\n",
        "\n",
        "    Args:\n",
        "        text (str): The input text to be chunked.\n",
        "        max_length (int, optional): Maximum length of each chunk. Default is 1000.\n",
        "\n",
        "    Returns:\n",
        "        List[str]: A list of text chunks, each containing full sentences.\n",
        "    \"\"\"\n",
        "\n",
        "    # Split text into sentences while ensuring punctuation (. ! ?) stays at the end\n",
        "    sentences = deque(re.split(r'(?<=[.!?])\\s+', text.replace('\\n', ' ')))\n",
        "\n",
        "    # An empty list to store the final chunks\n",
        "    chunks = []\n",
        "\n",
        "    # Temporary string to hold the current chunk\n",
        "    chunk_text = \"\"\n",
        "\n",
        "    while sentences:\n",
        "        # Access sentence from the deque and strip any extra spaces\n",
        "        sentence = sentences.popleft().strip()\n",
        "\n",
        "        # Check if the sentence is non-empty before processing\n",
        "        if sentence:\n",
        "            # If adding this sentence exceeds max_length and chunk_text is not empty, store the current chunk\n",
        "            if len(chunk_text) + len(sentence) > max_length and chunk_text:\n",
        "\n",
        "                # Save the current chunk\n",
        "                chunks.append(chunk_text)\n",
        "\n",
        "                # Start a new chunk with the current sentence\n",
        "                chunk_text = sentence\n",
        "            else:\n",
        "                # Append the sentence to the current chunk with a space\n",
        "                chunk_text += \" \" + sentence\n",
        "\n",
        "    # Add the last chunk if there's any remaining text\n",
        "    if chunk_text:\n",
        "        chunks.append(chunk_text)\n",
        "\n",
        "    return chunks"
      ],
      "metadata": {
        "id": "wsYaJ2xAQ-9D"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "chunks = text_chunk(pdf_text)"
      ],
      "metadata": {
        "id": "stO6tscmToxv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Number of chunks ={len(chunks)}\")\n",
        "print(chunks[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OrPVFAyXTsaQ",
        "outputId": "d71e16f7-b72b-424a-8aec-b2bbfff4c382"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of chunks =36\n",
            " Attention Is All You Need Ashish Vaswani\u0003 Google Brain avaswani@google.comNoam Shazeer\u0003 Google Brain noam@google.comNiki Parmar\u0003 Google Research nikip@google.comJakob Uszkoreit\u0003 Google Research usz@google.com Llion Jones\u0003 Google Research llion@google.comAidan N. Gomez\u0003y University of Toronto aidan@cs.toronto.eduŁukasz Kaiser\u0003 Google Brain lukaszkaiser@google.com Illia Polosukhin\u0003z illia.polosukhin@gmail.com Abstract The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring signiﬁcantly less time to train.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Create the Vector Store**"
      ],
      "metadata": {
        "id": "zyfi5blWT-cn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Set up Chromadb\n",
        "import chromadb\n",
        "from chromadb.utils import embedding_functions\n",
        "from chromadb.api.models import Collection\n",
        "\n",
        "def create_vector_store(db_path: str) -> Collection:\n",
        "    \"\"\"\n",
        "    Creates a persistent ChromaDB vector store with OpenAI embeddings.\n",
        "\n",
        "    Args:\n",
        "        db_path (str): Path where the ChromaDB database will be stored.\n",
        "\n",
        "    Returns:\n",
        "        Collection: A ChromaDB collection object for storing and retrieving embedded vectors.\n",
        "    \"\"\"\n",
        "\n",
        "    # Initialize a ChromaDB PersistentClient with the specified database path\n",
        "    client = chromadb.PersistentClient(path=db_path)\n",
        "\n",
        "    # Create an embedding function using OpenAI's text embedding model\n",
        "    embeddings = embedding_functions.OpenAIEmbeddingFunction(\n",
        "        api_key=userdata.get(\"OPENAI_API_KEY\"),  # Retrieve API key from user data\n",
        "        model_name=\"text-embedding-3-small\"  # Specify the embedding model\n",
        "    )\n",
        "\n",
        "    # Create a new collection in the ChromaDB database with the embedding function\n",
        "    db = client.create_collection(\n",
        "        name=\"pdf_chunks\",  # Name of the collection where embeddings will be stored\n",
        "        embedding_function=embeddings  # Apply the embedding function\n",
        "    )\n",
        "\n",
        "    # Return the created ChromaDB collection\n",
        "    return db\n"
      ],
      "metadata": {
        "id": "G36-SWrFUANV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Insert chunks into vector store\n",
        "import os\n",
        "import uuid\n",
        "\n",
        "def insert_chunks_vectordb(chunks: List[str], db: Collection, file_path: str) -> None:\n",
        "    \"\"\"\n",
        "    Inserts text chunks into a ChromaDB vector store with metadata.\n",
        "\n",
        "    Args:\n",
        "        chunks (List[str]): List of text chunks to be stored.\n",
        "        db (Collection): The ChromaDB collection where the chunks will be inserted.\n",
        "        file_path (str): Path of the source file for metadata.\n",
        "\n",
        "    Returns:\n",
        "        None\n",
        "    \"\"\"\n",
        "\n",
        "    # Extract the file name from the given file path\n",
        "    file_name = os.path.basename(file_path)\n",
        "\n",
        "    # Generate unique IDs for each chunk\n",
        "    id_list = [str(uuid.uuid4()) for _ in range(len(chunks))]\n",
        "\n",
        "    # Create metadata for each chunk, storing the chunk index and source file name\n",
        "    metadata_list = [{\"chunk\": i, \"source\": file_name} for i in range(len(chunks))]\n",
        "\n",
        "    # Define batch size for inserting chunks to optimize performance\n",
        "    batch_size = 40\n",
        "\n",
        "    # Insert chunks into the database in batches\n",
        "    for i in range(0, len(chunks), batch_size):\n",
        "        end_id = min(i + batch_size, len(chunks))  # Ensure we don't exceed list length\n",
        "\n",
        "        # Add the batch of chunks to the vector store\n",
        "        db.add(\n",
        "            documents=chunks[i:end_id],\n",
        "            metadatas=metadata_list[i:end_id],\n",
        "            ids=id_list[i:end_id]\n",
        "        )\n",
        "\n",
        "    print(f\"{len(chunks)} chunks added to the vector store\")\n"
      ],
      "metadata": {
        "id": "Mz5u2hy-UeUm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Retrieve Chunks**"
      ],
      "metadata": {
        "id": "duNeKQucZBnW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Any, List\n",
        "\n",
        "def retrieve_chunks(db: Collection, query: str, n_results: int = 2) -> List[Any]:\n",
        "    \"\"\"\n",
        "    Retrieves relevant chunks from the  vector store for the given query.\n",
        "\n",
        "    Args:\n",
        "        db (Collection): The vector store object\n",
        "        query (str): The search query text.\n",
        "        n_results (int, optional): The number of relevant chunks to retrieve. Defaults to 2.\n",
        "\n",
        "    Returns:\n",
        "        List[Any]: A list of relevant chunks retrieved from the vector store.\n",
        "    \"\"\"\n",
        "\n",
        "    # Perform a query on the database to get the most relevant chunks\n",
        "    relevant_chunks = db.query(query_texts=[query], n_results=n_results)\n",
        "\n",
        "    # Return the retrieved relevant chunks\n",
        "    return relevant_chunks\n"
      ],
      "metadata": {
        "id": "OiLTV_9jZGCy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Build Context**"
      ],
      "metadata": {
        "id": "arKLWO02ZrrF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def build_context(relevant_chunks) -> str:\n",
        "    \"\"\"\n",
        "    Builds a single context string by combining texts from relevant chunks.\n",
        "\n",
        "    Args:\n",
        "        relevant_chunks: relevant chunks retrieved from the vector store.\n",
        "\n",
        "    Returns:\n",
        "        str: A single string containing all document chunks combined with newline separators.\n",
        "    \"\"\"\n",
        "\n",
        "    # combine the text from relevant chunks with newline separator\n",
        "    context = \"\\n\".join(relevant_chunks['documents'][0])\n",
        "\n",
        "    # Return the combined context string\n",
        "    return context\n"
      ],
      "metadata": {
        "id": "Re7I3y4iZwuz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Build RAG Pipeline**"
      ],
      "metadata": {
        "id": "j-X0x6gtbp2p"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from typing import Tuple\n",
        "\n",
        "def get_context(pdf_path: str, query: str, db_path: str) -> Tuple[str, str]:\n",
        "    \"\"\"\n",
        "    Retrieves the relevant chunks from the vector store and then builds context from them.\n",
        "\n",
        "    Args:\n",
        "        pdf_path (str): The file path to the PDF document.\n",
        "        query (str): The query string to search within the vector store.\n",
        "        db_path (str): The file path to the persistent vector store database.\n",
        "\n",
        "    Returns:\n",
        "        Tuple[str, str]: A tuple containing the context related to the query and the original query string.\n",
        "    \"\"\"\n",
        "\n",
        "    # Check if the vector store already exists\n",
        "    if os.path.exists(db_path):\n",
        "        print(\"Loading existing vector store...\")\n",
        "\n",
        "        # Initialize the persistent client for the existing database\n",
        "        client = chromadb.PersistentClient(path=db_path)\n",
        "\n",
        "        # Create the embedding function using OpenAI embeddings\n",
        "        embeddings = embedding_functions.OpenAIEmbeddingFunction(\n",
        "            api_key=userdata.get(\"OPENAI_API_KEY\"),  # Fetch API key from userdata\n",
        "            model_name=\"text-embedding-3-small\"      # Specify the embedding model\n",
        "        )\n",
        "\n",
        "        # Get the collection of PDF chunks from the existing vector store\n",
        "        db = client.get_collection(name=\"pdf_chunks\", embedding_function=embeddings)\n",
        "    else:\n",
        "        print(\"Creating new vector store...\")\n",
        "\n",
        "        # Extract text from the provided PDF\n",
        "        pdf_text = text_extract(pdf_path)\n",
        "\n",
        "        # Chunk the extracted text\n",
        "        chunks = text_chunk(pdf_text)\n",
        "\n",
        "        # Create a new vector store\n",
        "        db = create_vector_store(db_path)\n",
        "\n",
        "        # Insert the text chunks into the vector store\n",
        "        insert_chunks_vectordb(chunks, db, pdf_path)\n",
        "\n",
        "    # Retrieve the relevant chunks based on the query\n",
        "    relevant_chunks = retrieve_chunks(db, query)\n",
        "\n",
        "    # Build the context from the relevant chunks\n",
        "    context = build_context(relevant_chunks)\n",
        "\n",
        "    # Return the context and the original query\n",
        "    return context, query\n"
      ],
      "metadata": {
        "id": "azZAY-BraBwf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def get_prompt(context: str, query: str) -> str:\n",
        "    \"\"\"\n",
        "    Generates a rag prompt based on the given context and query.\n",
        "\n",
        "    Args:\n",
        "        context (str): The context the LLM should use to answer the question.\n",
        "        query (str): The user query that needs to be answered based on the context.\n",
        "\n",
        "    Returns:\n",
        "        str: The generated rag prompt.\n",
        "    \"\"\"\n",
        "\n",
        "    # Format the prompt with the provided context and query\n",
        "    rag_prompt = f\"\"\" You are an AI model trained for question answering. You should answer the\n",
        "    given question based on the given context only.\n",
        "    Question : {query}\n",
        "    \\n\n",
        "    Context : {context}\n",
        "    \\n\n",
        "    If the answer is not present in the given context, respond as: The answer to this question is not available\n",
        "    in the provided content.\n",
        "    \"\"\"\n",
        "\n",
        "    # Return the formatted prompt\n",
        "    return rag_prompt\n"
      ],
      "metadata": {
        "id": "gzAC5TswbrnO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from litellm import completion\n",
        "\n",
        "def get_response(rag_prompt: str) -> str:\n",
        "    \"\"\"\n",
        "    Sends a prompt to the OpenAI LLM and returns the answer.\n",
        "\n",
        "    Args:\n",
        "        rag_prompt (str): The rag prompt.\n",
        "\n",
        "    Returns:\n",
        "        str: The LLM generated answer.\n",
        "    \"\"\"\n",
        "    # Specify the LLM to use\n",
        "    model = \"openai/gpt-4o-mini\"\n",
        "\n",
        "    # Prepare the message to be sent to the model\n",
        "    messages = [{\"role\": \"user\", \"content\": rag_prompt}]\n",
        "\n",
        "    # Call the completion function to get a response from the model\n",
        "    response = completion(model=model, messages=messages, temperature=0)\n",
        "\n",
        "    # Return the answer\n",
        "    answer = response.choices[0].message.content\n",
        "    return answer\n"
      ],
      "metadata": {
        "id": "8LxkJY_gcd5g"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def rag_pipeline(pdf_path: str, query: str, db_path: str) -> str:\n",
        "    \"\"\"\n",
        "    Runs a Retrieval-Augmented Generation (RAG) pipeline to retrieve context from a vector store,\n",
        "    generate the rag prompt, and then get the answer from the model.\n",
        "\n",
        "    Args:\n",
        "        pdf_path (str): The file path to the PDF document from which context is extracted.\n",
        "        query (str): The query for which a response is needed, based on the context.\n",
        "        db_path (str): The file path to the persistent vector store database used for context retrieval.\n",
        "\n",
        "    Returns:\n",
        "        str: The model's response based on the context and the provided query.\n",
        "    \"\"\"\n",
        "\n",
        "    # get the context\n",
        "    context, query = get_context(pdf_path, query, db_path)\n",
        "\n",
        "    # Generate the rag prompt based on the context and query\n",
        "    rag_prompt = get_prompt(context, query)\n",
        "\n",
        "    # Get the response from the model using the rag prompt\n",
        "    response = get_response(rag_prompt)\n",
        "\n",
        "    # Return the model's response\n",
        "    return response\n"
      ],
      "metadata": {
        "id": "GQLzNtoWd3pd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Run RAG Pipeline**"
      ],
      "metadata": {
        "id": "LdIxhjaBn6d4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Set the chroma DB path\n",
        "current_dir = \"/content/rag\"\n",
        "persistent_directory = os.path.join(current_dir, \"db\", \"chroma_db_pdf\")\n",
        "\n",
        "# PDF path\n",
        "pdf_path = \"/content/attention_is_all_you_need.pdf\"\n",
        "\n",
        "# RAG query\n",
        "query = \"What is the transformer architecture?\""
      ],
      "metadata": {
        "id": "zyAPVswXegm8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the RAG pipeline\n",
        "answer = rag_pipeline(pdf_path, query, persistent_directory)"
      ],
      "metadata": {
        "id": "bZcKzayselVs",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "407c7a31-e8df-42f7-b853-16dde0ee2593"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Creating new vector store...\n",
            "36 chunks added to the vector store\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Query:{query}\")\n",
        "print(f\"Generated answer:{answer}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "djQeTMMNeoDB",
        "outputId": "1fa18a3f-8178-4b70-c407-9cc2f5f24e77"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Query:What is the transformer architecture?\n",
            "Generated answer:The Transformer architecture is a model that relies entirely on an attention mechanism to draw global dependencies between input and output, eschewing recurrence. It consists of stacked self-attention and point-wise, fully connected layers for both the encoder and decoder. The encoder is composed of a stack of six identical layers, each with a multi-head self-attention mechanism and a position-wise fully connected feed-forward network, along with residual connections and layer normalization. This architecture allows for significant parallelization and achieves state-of-the-art translation quality with relatively short training times.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# RAG query\n",
        "query = \"What is self-attention?\"\n",
        "\n",
        "# Run the RAG pipeline\n",
        "answer = rag_pipeline(pdf_path, query, persistent_directory)\n",
        "\n",
        "print(f\"Query:{query}\")\n",
        "print(f\"Generated answer:{answer}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_XJ4ulaInq48",
        "outputId": "81675e73-da48-4351-8832-0ebc77400584"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading existing vector store...\n",
            "Query:What is self-attention?\n",
            "Generated answer:Self-attention, sometimes called intra-attention, is an attention mechanism that relates different positions of a single sequence in order to compute a representation of that sequence. It has been successfully used in various tasks, including reading comprehension, abstractive summarization, textual entailment, and learning task-independent sentence representations.\n"
          ]
        }
      ]
    }
  ]
}